import os
import os.path as osp
import pickle

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from omegaconf import OmegaConf
from sklearn.manifold import TSNE
from umap import UMAP


def preprocess_traj(traj, h=16):
    """Cut a trajectory into flattened sliding windows of ``h`` consecutive steps.

    Each sample covers steps ``i .. i + h - 1`` and is encoded as the flattened
    actions of those steps followed by their flattened observations.

    Args:
        traj: Mapping with ``"actions"`` and ``"observations"`` entries that are
            indexable along the first axis (converted with ``np.asarray``). The
            number of windows is driven by the length of ``"actions"``.
        h: Window length (number of consecutive steps per sample); must be >= 1.

    Returns:
        A list of 1-D arrays, one per window. Empty when there are fewer than
        ``h`` actions.

    Raises:
        ValueError: if ``h`` is smaller than 1.
    """
    if h < 1:
        raise ValueError(f"window length h must be >= 1, got {h}")
    actions = np.asarray(traj["actions"])
    observations = np.asarray(traj["observations"])
    # n steps contain n - h + 1 full windows of length h (the last one starts at
    # n - h); a negative count simply yields an empty range.
    num_windows = len(actions) - h + 1
    return [
        np.concatenate(
            [actions[i : i + h].reshape(-1), observations[i : i + h].reshape(-1)]
        )
        for i in range(num_windows)
    ]


def _frame_plot(x_range, y_range, margin):
    """Pad the axes by ``margin`` around the data ranges and strip ticks, legend and figure padding."""
    plt.xlim(x_range[0] - margin, x_range[1] + margin)
    plt.ylim(y_range[0] - margin, y_range[1] + margin)
    plt.xticks([])
    plt.yticks([])
    plt.legend([], [], frameon=False)
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)


def _plot_expert_density(expert_tsne, **kwargs):
    """Draw the filled orange 2-D KDE of the expert embedding on the current axes."""
    sns.kdeplot(
        x=expert_tsne[:, 0],
        y=expert_tsne[:, 1],
        fill=True,
        bw_adjust=0.5,
        color="orange",
        **kwargs,
    )


def plot_tsne(
    task_name,
    expert_trajs,
    diffgro_trajs,
    guide_hue,
    p=50,
    num_bins=8,
    out_dir="visualizations",
):
    """Embed expert and guided samples jointly with t-SNE and save KDE plots.

    One figure is saved per hue threshold. Each shows the expert density
    (orange) overlaid with the density of the guided samples whose hue is
    <= that threshold, so the sequence shows the guided distribution growing
    cumulatively. A final figure shows the expert density alone.

    Args:
        task_name: Only used in the output filenames.
        expert_trajs: 2-D array of expert samples, one row per sample.
        diffgro_trajs: Mapping whose ``"guide"`` entry is a 2-D array of guided
            samples with the same number of columns as ``expert_trajs``.
        guide_hue: One scalar per row of ``diffgro_trajs["guide"]``. It is split
            into ``num_bins`` equal-width cumulative thresholds.
        p: t-SNE perplexity; also used in the output filenames.
        num_bins: Number of hue thresholds (one figure each).
        out_dir: Directory the PNG files are written to; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)

    tsne = TSNE(n_components=2, perplexity=p, random_state=0)
    trajs_tsne = tsne.fit_transform(
        np.concatenate([expert_trajs, diffgro_trajs["guide"]], axis=0)
    )
    # Embedding rows follow the input order: expert rows first, then guided rows.
    expert_tsne = trajs_tsne[: len(expert_trajs)]
    guide_tsne = trajs_tsne[len(expert_trajs) :]

    x_range = (np.min(trajs_tsne[:, 0]), np.max(trajs_tsne[:, 0]))
    y_range = (np.min(trajs_tsne[:, 1]), np.max(trajs_tsne[:, 1]))

    guide_hue = np.asarray(guide_hue)
    hue_min = np.min(guide_hue)
    # linspace pins the last threshold to exactly max(guide_hue). Together with
    # "<=", the final cumulative bin therefore contains every guided sample,
    # including the max-hue ones and the constant-hue case.
    thresholds = np.linspace(hue_min, np.max(guide_hue), num_bins + 1)[1:]
    for end in thresholds:
        plt.figure(figsize=(10, 10))
        _plot_expert_density(expert_tsne, alpha=0.5, zorder=1)

        idx = np.flatnonzero(guide_hue <= end)
        print(f"Guide Hue: {hue_min:.3f} ~ {end:.3f}, Num: {len(idx)}")
        # A 2-D KDE needs at least two points; draw only the expert density otherwise.
        if len(idx) > 1:
            sns.kdeplot(
                x=guide_tsne[idx, 0],
                y=guide_tsne[idx, 1],
                fill=True,
                # Bandwidth grows with the subset's share of all guided samples.
                bw_adjust=0.25 + 0.75 * len(idx) / len(guide_tsne),
                alpha=0.5,
                zorder=-1,
            )
        _frame_plot(x_range, y_range, margin=20)
        plt.savefig(osp.join(out_dir, f"tsne_{task_name}_{end:.3f}_{p}.png"))
        plt.close()  # release the figure; otherwise one stays open per bin

    plt.figure(figsize=(10, 10))
    _plot_expert_density(expert_tsne, alpha=0.7)
    _frame_plot(x_range, y_range, margin=15)
    plt.savefig(osp.join(out_dir, f"tsne_{task_name}_expert_{p}.png"))
    plt.close()


def _load_expert_trajs(expert_path):
    """Window every pickled expert trajectory in ``expert_path`` and stack all windows.

    Raises:
        ValueError: if no file yields any window.
    """
    windows = []
    # sorted() makes the row order independent of filesystem listing order, so
    # the seeded subsample taken later is reproducible.
    for filename in sorted(os.listdir(expert_path)):
        # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted data.
        with open(osp.join(expert_path, filename), "rb") as f:
            windows += preprocess_traj(pickle.load(f))
    if not windows:
        raise ValueError(f"No expert trajectory windows found in {expert_path!r}")
    return np.stack(windows, axis=0)


def _load_guided_trajs(diffgro_path, guides):
    """Load and window the pickled trajectories in each ``diffgro_path/<guide>`` directory.

    For every guide other than ``"default"``, the text before the first ``_`` in
    each filename is parsed as a float ``delta``. That value is appended as an
    extra last column to every window from the file. Guides whose directory
    yields no windows are left out of the result.

    Returns:
        Dict mapping guide name -> 2-D array of stacked windows.
    """
    diffgro_trajs = {}
    for guide in guides:
        guide_dir = osp.join(diffgro_path, guide)
        windows = []
        for filename in sorted(os.listdir(guide_dir)):
            with open(osp.join(guide_dir, filename), "rb") as f:
                traj_windows = preprocess_traj(pickle.load(f))
            if guide != "default":
                delta = float(filename.split("_")[0])
                traj_windows = [np.concatenate([w, [delta]]) for w in traj_windows]
            windows += traj_windows
        if windows:
            diffgro_trajs[guide] = np.stack(windows, axis=0)
    return diffgro_trajs


def visualize(args):
    """Load expert and guided trajectories, subsample them, and produce t-SNE plots.

    ``args`` must provide ``task_name`` and ``domain_name``. It may also provide:

    - ``seed``: seeds the subsampling; unseeded if absent.
    - ``expert_domain_name``: dataset folder for the expert data; defaults to
      ``"metaworld_T1"``.

    Raises:
        ValueError: if no expert windows or no non-default guided windows are found.
    """
    expert_domain = getattr(args, "expert_domain_name", None) or "metaworld_T1"
    expert_path = osp.join("datasets", expert_domain, args.task_name, "trajectory")
    expert_trajs = _load_expert_trajs(expert_path)

    diffgro_path = osp.join("visualizations", args.domain_name, args.task_name)
    # Other available guide folders: x_faster, x_slower, y_faster, y_slower,
    # slower, default.
    diffgro_trajs = _load_guided_trajs(diffgro_path, ["faster"])

    print("expert Trajs:", expert_trajs.shape)
    for guide, trajs in diffgro_trajs.items():
        print(f"{guide} Trajs:", trajs.shape)

    guided = [trajs for guide, trajs in diffgro_trajs.items() if guide != "default"]
    if not guided:
        raise ValueError(f"No guided trajectory windows found under {diffgro_path!r}")

    # Use the configured seed so repeated runs draw the same subsample.
    rng = np.random.default_rng(getattr(args, "seed", None))
    expert_trajs = rng.permutation(expert_trajs)[:10000]
    guide_all = rng.permutation(np.concatenate(guided, axis=0))[:15000]

    # The last column holds the per-file delta appended at load time. Split it
    # off as the hue so only the window features are embedded.
    diffgro_trajs["guide"], guide_hue = guide_all[:, :-1], guide_all[:, -1]
    for guide, trajs in diffgro_trajs.items():
        print(f"{guide} Trajs:", trajs.shape)

    for p in [20]:
        plot_tsne(args.task_name, expert_trajs, diffgro_trajs, guide_hue, p=p)


if __name__ == "__main__":
    sns.set_style(style="white")
    # Script configuration; "drawer-close-variant-v2" is an alternative task_name.
    config = {
        "domain_name": "metaworld",
        "task_name": "button-press-variant-v2",
        "seed": 777,
    }
    visualize(OmegaConf.create(config))
